{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "JAX in Action - Chapter 4 - Different ways of getting derivatives",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "## Different ways of getting derivatives"
      ],
      "metadata": {
        "id": "09pyk3IhMw0c"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "A simple mathematical function we will use in our examples:"
      ],
      "metadata": {
        "id": "1rKjlHKqMaH8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def f(x):\n",
        "    \"\"\"Evaluate x**4 + 12*x + 1/x for the given x.\n",
        "\n",
        "    Only arithmetic operators are applied to x, so any operand that supports\n",
        "    them can be passed. A Python int or float zero raises ZeroDivisionError\n",
        "    because of the 1/x term.\n",
        "    \"\"\"\n",
        "    quartic_term = x**4\n",
        "    linear_term = 12*x\n",
        "    reciprocal_term = 1/x\n",
        "    return quartic_term + linear_term + reciprocal_term"
      ],
      "metadata": {
        "id": "OZprmjL1MXQ8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Manual differentiation"
      ],
      "metadata": {
        "id": "dmxr4Jt3M7lU"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "A closed-form expression for a derivative calculated manually:"
      ],
      "metadata": {
        "id": "Ri0HlD2pNQpC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def df(x):\n",
        "    \"\"\"Evaluate 4*x**3 + 12 - 1/x**2, the hand-derived derivative of x**4 + 12*x + 1/x.\n",
        "\n",
        "    A Python int or float zero raises ZeroDivisionError because of the 1/x**2 term.\n",
        "    \"\"\"\n",
        "    cubic_term = 4*x**3\n",
        "    inverse_square_term = 1/x**2\n",
        "    return cubic_term + 12 - inverse_square_term"
      ],
      "metadata": {
        "id": "R1b5BYYVMhno"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "x = 11.0"
      ],
      "metadata": {
        "id": "R4M_FF6qMtTC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(f(x))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zAFml-L8Pp3G",
        "outputId": "ffa41e85-c8c0-4fed-f7ff-be06aa3e9bd3"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "14773.09090909091\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(df(x))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h9ja42mPONC4",
        "outputId": "b2523e6d-9c96-490c-ea33-05a0d09135e0"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "5335.99173553719\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Symbolic differentiation"
      ],
      "metadata": {
        "id": "DAjk74l_RDeP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import sympy"
      ],
      "metadata": {
        "id": "xgmajTIQRGkj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "x_sym = sympy.symbols('x')"
      ],
      "metadata": {
        "id": "0JgYe_qCRYjd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "f_sym = f(x_sym)\n",
        "df_sym = sympy.diff(f_sym)"
      ],
      "metadata": {
        "id": "euk3sncgSVd5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(f_sym)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6tloMKo8Sk8q",
        "outputId": "6fe97926-0ff3-4743-edd0-817bcd07978d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "x**4 + 12*x + 1/x\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(df_sym)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BC979fSWTSs2",
        "outputId": "85b7691d-6aa5-4dfb-f687-f31386d88845"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "4*x**3 + 12 - 1/x**2\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Numerically evaluate the symbolic expression at the float x.\n",
        "# The substitution key must be the SymPy symbol x_sym; keying on the float x\n",
        "# itself matches nothing and leaves the expression unevaluated.\n",
        "f_sym.evalf(subs={x_sym: x})"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        },
        "id": "R-sBLkx9TYEW",
        "outputId": "9bbcb57c-3c57-4043-de63-5dbe5669ce90"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "14773.0909090909"
            ],
            "text/latex": "$\\displaystyle 14773.0909090909$"
          },
          "metadata": {},
          "execution_count": 24
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Turn the symbolic expression into a plain numeric callable.\n",
        "# It gets its own name: rebinding f here would shadow the Python function f\n",
        "# that f_sym was built from.\n",
        "f_lambdified = sympy.lambdify(x_sym, f_sym)\n",
        "print(f_lambdified(x))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "D5cANcsWVrWg",
        "outputId": "332ab496-b8ad-4807-8cd6-fa163b0986c3"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "14773.09090909091\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Turn the symbolic derivative into a plain numeric callable.\n",
        "# It gets its own name: rebinding df here would replace the hand-written\n",
        "# closed-form derivative with a different kind of object.\n",
        "df_lambdified = sympy.lambdify(x_sym, df_sym)\n",
        "print(df_lambdified(x))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lBnfev7WTflH",
        "outputId": "12fd2b09-c114-4216-a838-21dc3c57c2fe"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "5335.99173553719\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Numeric differentiation"
      ],
      "metadata": {
        "id": "96pn0MCNFEWk"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "x = 11.0\n",
        "dx = 1e-6\n"
      ],
      "metadata": {
        "id": "tje5rh6_FGO-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Forward-difference approximation of the derivative of f at x, using step dx.\n",
        "df_x_numeric = (f(x + dx) - f(x)) / dx\n",
        "\n",
        "print(df_x_numeric)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nc85mBq_FK_Z",
        "outputId": "9d6d8f40-7abf-4a71-c754-f823054f67db"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "5335.992456821259\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Automatic differentiation"
      ],
      "metadata": {
        "id": "3tAsjSvRGBke"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import jax"
      ],
      "metadata": {
        "id": "zeEW8C9oWLoe"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "df = jax.grad(f)"
      ],
      "metadata": {
        "id": "-vZjFOruYlt3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(df(x))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qXZNL9I0GHRf",
        "outputId": "0b631651-cbc4-426b-e454-f3c193412d95"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "5335.9917\n"
          ]
        }
      ]
    }
  ]
}